#!/usr/bin/env python3
# H23 — Molecular Multi-Center Cohesion (L2 bond-graph edges for 2–3 centers)
#
# CONTROL (unchanged, theory-faithful):
#   • Present-act, boolean/ordinal. No weights, no potentials, no RNG in control.
#   • DDA 1/r per center and per integer shell: A[r] += rate_num; if A[r] >= r, fire shell r and set A[r] -= r.
#   • Whole-shell firing, rails OFF (as in D9) to avoid parity aliasing; neighbor requirements are not needed here.
#
# DIAGNOSTICS ONLY:
#   • Shared interior radial window r ∈ [r_band_min, r_band_max] across all centers.
#   • At each tick, for each center, collect the list of fired radii within the band.
#   • A pair (i,j) exhibits an L2 “contact” event at tick t if ∃ r∈R_i(t), s∈R_j(t) with circle-intersection:
#         |r − s| ≤ d_ij ≤ r + s   (d_ij = Euclidean center-to-center distance; centers sit on integer grid points, d_ij itself is real-valued)
#   • The pair’s bond strength is the contact rate:  contacts_ij / H.
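#   • Acceptance requires every required edge's contact rate to reach edge_rate_min, every center's band-coverage
#     rate to reach coverage_rate_min (default 0), and, with 3 centers, the triad rate to reach triad_rate_min if given.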
#
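# OUTPUT (paths relative to --outdir):
#   • outputs/metrics/h23_edges.csv — per-edge contact counts and rates, per-center band coverage, triad rate (3 centers only)
#   • outputs/audits/h23_audit.json — full metrics, acceptance thresholds and the overall pass flag
#   • outputs/run_info/result_line.txt — one-line summary: PASS=True/False plus per-edge rates (and triad rate)
#   • config/manifest_h23.json, logs/env.txt — copy of the manifest and an environment stamp (written by main)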
#
# NOTE: All time/space decisions are integer / boolean; only diagnostics compute floats for reporting thresholds.

import argparse, json, math, os, csv, sys
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Dict, Tuple

# ------------------- utils -------------------
def utc_timestamp():
    """Return the current UTC time as a filesystem-safe stamp.

    Format is ``YYYY-MM-DDTHH-MM-SSZ``: ISO-8601-like, but with dashes instead
    of colons in the time part so it can be used in file names.
    """
    from datetime import datetime, timezone
    now_utc = datetime.now(timezone.utc)
    return now_utc.strftime("%Y-%m-%dT%H-%M-%SZ")

def ensure_dirs(root: str, *subs: str):
    """Create ``root/<sub>`` for every sub-path given, including parents.

    Directories that already exist are left untouched (``exist_ok=True``).
    """
    targets = (os.path.join(root, sub) for sub in subs)
    for target in targets:
        os.makedirs(target, exist_ok=True)

def write_text(path: str, text: str):
    """Overwrite the file at ``path`` with ``text``, encoded as UTF-8."""
    with open(path, mode="w", encoding="utf-8") as handle:
        handle.write(text)

def dump_json(path: str, obj: dict):
    """Write ``obj`` to ``path`` as UTF-8 JSON, indented by 2 with sorted keys.

    Sorted keys keep the output byte-stable across runs for the same content.
    """
    with open(path, "w", encoding="utf-8") as out:
        json.dump(obj, out, sort_keys=True, indent=2)

# ------------------- geometry -------------------
def isqrt(n: int) -> int:
    """Return floor(sqrt(n)) for a non-negative integer ``n`` (wraps math.isqrt)."""
    root = math.isqrt(n)
    return root

def edge_radius(cx: int, cy: int, N: int) -> int:
    """Largest integer r so that all points with floor(sqrt((x-cx)^2+(y-cy)^2)) <= r lie inside grid.

    Equivalently: the distance from (cx, cy) to the nearest edge of an N x N
    grid indexed 0..N-1 along both axes.
    """
    last_index = N - 1
    dx_lo, dy_lo = cx, cy
    dx_hi, dy_hi = last_index - cx, last_index - cy
    return min(dx_lo, dy_lo, dx_hi, dy_hi)

def dist2(a: Tuple[int,int], b: Tuple[int,int]) -> int:
    """Squared Euclidean distance between 2-D points ``a`` and ``b`` (no sqrt, stays integral)."""
    dx = a[0] - b[0]
    dy = a[1] - b[1]
    return dx**2 + dy**2

# ------------------- present-act DDA 1/r -------------------
@dataclass
class Center:
    """Mutable per-center DDA state: grid position, shell limit, per-tick increment, accumulators.

    Fields are positional in this order (build_centers constructs instances
    positionally), so do not reorder them.
    """
    cx: int  # center x coordinate on the N x N grid
    cy: int  # center y coordinate on the N x N grid
    rmax: int  # largest shell radius advanced by fire_radii (shells 1..rmax)
    rate: int  # amount added to every shell accumulator on each tick (rate_num)
    accum: List[int]  # length rmax+1, index by r; index 0 is never touched by fire_radii

def _as_int(value, what: str) -> int:
    """Return ``value`` as an int, rejecting non-integral values with SystemExit."""
    as_int = int(value)
    if as_int != value:
        raise SystemExit(f"{what} must be an integer, got {value!r}")
    return as_int


def build_centers(N: int, centers_cfg: List[Dict], rate_num: int, outer_margin: int) -> List[Center]:
    """Create one DDA state per configured center, all sharing a single rmax.

    The shared rmax is the smallest grid-edge radius over all centers minus
    ``outer_margin``. As a result, every center advances the same shells
    (1..rmax), and the diagnostic band means the same thing for each center.

    Args:
        N: grid side length.
        centers_cfg: list of dicts with integer-valued "cx" and "cy".
        rate_num: per-tick DDA increment applied to every shell.
        outer_margin: number of shells kept clear of the nearest grid edge.

    Returns:
        One Center per entry of ``centers_cfg``, in the same order, each with
        a fresh zeroed accumulator list of length rmax+1.

    Raises:
        SystemExit: if no centers are configured, a coordinate is not an
            integer, or the shared rmax is below 4.
    """
    if not centers_cfg:
        # min() below would otherwise fail with an unrelated "empty sequence" error.
        raise SystemExit("no centers configured")
    # Coordinates feed rmax, which sizes the accumulator list, so they must be
    # exact integers. JSON values such as 10.0 are accepted; 10.5 is rejected.
    coords = [(_as_int(c["cx"], "center cx"), _as_int(c["cy"], "center cy")) for c in centers_cfg]
    # Shared interior rmax across centers, for band/window consistency.
    rmax_shared = min(edge_radius(cx, cy, N) for cx, cy in coords) - int(outer_margin)
    if rmax_shared < 4:
        raise SystemExit("rmax too small; increase grid or reduce outer_margin")
    return [Center(cx, cy, rmax_shared, rate_num, [0] * (rmax_shared + 1)) for cx, cy in coords]

def fire_radii(center: Center, r_lo: int, r_hi: int) -> List[int]:
    """Advance the center's DDA by one tick and report in-band firings.

    For every shell r = 1..center.rmax: add center.rate to accum[r]. If the
    accumulator has reached r, the shell fires once and r is subtracted.
    The accumulators are mutated in place; shells outside [r_lo, r_hi] still
    advance and fire, they are just not reported.

    Returns the fired radii within [r_lo, r_hi], in increasing order.
    """
    acc = center.accum
    in_band = []
    for shell in range(1, center.rmax + 1):
        acc[shell] += center.rate
        if acc[shell] < shell:
            continue
        acc[shell] -= shell
        if r_lo <= shell <= r_hi:
            in_band.append(shell)
    return in_band

# ------------------- L2 contact detection -------------------
def rings_intersect(r1: int, r2: int, dij: float) -> bool:
    """Return True if circles of radii r1 and r2, with centers dij apart, share a point.

    Tangency, both internal (dij == |r1 - r2|) and external (dij == r1 + r2),
    counts as intersection.
    """
    lower = abs(r1 - r2)
    upper = r1 + r2
    return lower <= dij <= upper

def contact_event(fired_i: List[int], fired_j: List[int], dij: float) -> bool:
    """Return True if any fired shell of center i intersects any fired shell of center j.

    Args:
        fired_i: radii fired by center i this tick (positive shell indices, as
            returned by fire_radii).
        fired_j: radii fired by center j this tick.
        dij: Euclidean distance between the two centers.

    rings_intersect decides intersection exactly. It already encodes the
    triangle-inequality bounds |dij - ri| <= rj <= dij + ri, so no separate
    float pre-filter is needed. Pairs are tested in the order (fired_i outer,
    fired_j inner), and the test stops at the first hit.
    """
    return any(rings_intersect(ri, rj, dij) for ri in fired_i for rj in fired_j)

# ------------------- main sim -------------------
# Absolute slack when comparing a measured rate against a threshold. A rate
# that equals the threshold up to float rounding therefore still passes.
_RATE_EPS = 1e-12


def _meets_min(rate: float, minimum: float) -> bool:
    """Return True if ``rate`` reaches ``minimum`` within _RATE_EPS."""
    return rate + _RATE_EPS >= minimum


def _required_edge_pairs(edges_required, nC: int) -> List[Tuple[int, int]]:
    """Normalize manifest edge specs ([i, j] in either order) to (min, max) index tuples.

    Raises SystemExit for self-loops or indices outside 0..nC-1.
    """
    out: List[Tuple[int, int]] = []
    for ij in edges_required:
        a, b = sorted((int(ij[0]), int(ij[1])))
        if a == b or a < 0 or b >= nC:
            raise SystemExit(f"edges_required entry {list(ij)} is not a valid edge for {nC} centers")
        out.append((a, b))
    return out


def run_sim(M: dict, outdir: str) -> Dict:
    """Run the H23 multi-center cohesion diagnostic and write its artifacts.

    Steps:
        1. Read the manifest ``M``. Keys used: grid.N, centers, H,
           control.rate_num, outer_margin, band.r_min, band.r_max, acceptance.
        2. Advance every center's DDA for H ticks.
        3. Count, per tick, the L2 contact events per center pair, the band
           coverage per center and, for 3 centers, the triad events (all
           three edges in contact at once).
        4. Apply the acceptance thresholds.
        5. Write outputs/metrics/h23_edges.csv, outputs/audits/h23_audit.json
           and outputs/run_info/result_line.txt under ``outdir``, and print
           the result line.

    Returns:
        The audit dictionary (the same content written to h23_audit.json).

    Raises:
        SystemExit: on an invalid configuration (band, H, center count,
            center coordinates/rmax, or required edges).
    """
    N = int(M["grid"]["N"])
    centers_cfg = M["centers"]
    H = int(M["H"])
    rate_num = int(M["control"]["rate_num"])
    outer_margin = int(M["outer_margin"])
    band_lo = int(M["band"]["r_min"])
    band_hi = int(M["band"]["r_max"])
    # Explicit raises rather than `assert`, which `python -O` strips out.
    if not (1 <= band_lo < band_hi):
        raise SystemExit("Invalid band")
    if H < 1:
        raise SystemExit("H must be >= 1 (rates are computed as counts / H)")
    # Check the count before building. An empty list would otherwise fail
    # inside build_centers with an unrelated error.
    nC = len(centers_cfg)
    if not (2 <= nC <= 3):
        raise SystemExit("H23 expects 2 or 3 centers")

    centers = build_centers(N, centers_cfg, rate_num, outer_margin)
    rmax_shared = centers[0].rmax  # build_centers gives every center the same rmax
    if band_hi > rmax_shared:
        # Shells above rmax are never advanced, so that part of the band can never fire.
        print(f"warning: band r_max={band_hi} exceeds shared rmax={rmax_shared}; "
              f"shells {rmax_shared + 1}..{band_hi} never fire", file=sys.stderr)

    # Unordered center pairs (i < j), in lexicographic order.
    pairs = [(i, j) for i in range(nC) for j in range(i + 1, nC)]
    dij = {
        (i, j): math.sqrt(dist2((centers[i].cx, centers[i].cy), (centers[j].cx, centers[j].cy)))
        for (i, j) in pairs
    }

    contacts = {p: 0 for p in pairs}
    coverage = [0] * nC  # ticks in which the center fired at least one in-band shell
    triad_count = 0      # ticks in which all three edges were in contact (nC == 3 only)

    for _ in range(H):
        # Advance every center exactly once per tick, in configuration order.
        fired = [fire_radii(c, band_lo, band_hi) for c in centers]
        for idx, radii in enumerate(fired):
            if radii:
                coverage[idx] += 1

        tick_edges = []
        for (i, j) in pairs:
            hit = contact_event(fired[i], fired[j], dij[(i, j)])
            if hit:
                contacts[(i, j)] += 1
            tick_edges.append(hit)

        # With three centers, `pairs` holds exactly the three triangle edges.
        if nC == 3 and all(tick_edges):
            triad_count += 1

    # Rates
    contact_rate = {f"{i}-{j}": contacts[(i, j)] / H for (i, j) in pairs}
    coverage_rate = [cov / H for cov in coverage]
    triad_rate = triad_count / H if nC == 3 else 0.0

    # Acceptance
    acc = M["acceptance"]
    edge_min = float(acc["edge_rate_min"])
    cover_min = float(acc.get("coverage_rate_min", 0.0))
    if bool(acc.get("require_clique", False)):
        required_pairs = pairs
    else:
        required_pairs = _required_edge_pairs(acc.get("edges_required", []), nC)

    edges_ok = all(_meets_min(contact_rate[f"{i}-{j}"], edge_min) for (i, j) in required_pairs)
    coverage_ok = all(_meets_min(cr, cover_min) for cr in coverage_rate)
    triad_ok = True
    if nC == 3 and "triad_rate_min" in acc:
        triad_ok = _meets_min(triad_rate, float(acc["triad_rate_min"]))

    passed = bool(edges_ok and coverage_ok and triad_ok)

    # Create the output tree here as well, so run_sim also works without main().
    ensure_dirs(outdir, "outputs/metrics", "outputs/audits", "outputs/run_info")

    # Write metrics
    with open(os.path.join(outdir, "outputs/metrics", "h23_edges.csv"), "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["edge_i", "edge_j", "distance", "contacts", "rate"])
        for (i, j) in pairs:
            w.writerow([i, j, f"{dij[(i, j)]:.3f}", contacts[(i, j)], f"{contact_rate[f'{i}-{j}']:.6f}"])
        w.writerow([])
        w.writerow(["center", "coverage_rate"])
        for i, cr in enumerate(coverage_rate):
            w.writerow([i, f"{cr:.6f}"])
        if nC == 3:
            w.writerow([])
            w.writerow(["triad_rate", f"{triad_rate:.6f}"])

    audit = {
        "sim": "H23_molecular_cohesion",
        "grid": M["grid"],
        "centers": centers_cfg,
        "H": H,
        "outer_margin": outer_margin,
        "band": M["band"],
        "control": M["control"],
        "pairs": [{"i": i, "j": j, "distance": dij[(i, j)]} for (i, j) in pairs],
        "coverage_rate": coverage_rate,
        "contact_rate": contact_rate,
        "triad_rate": triad_rate,
        "accept": acc,
        "passed": passed,
    }
    dump_json(os.path.join(outdir, "outputs/audits", "h23_audit.json"), audit)

    line = f"H23 PASS={passed} " + " ".join(f"e{i}-{j}={contact_rate[f'{i}-{j}']:.3f}" for (i, j) in pairs)
    if nC == 3:
        line += f" triad={triad_rate:.3f}"
    write_text(os.path.join(outdir, "outputs/run_info", "result_line.txt"), line)
    print(line)
    return audit

def main():
    """CLI entry point for H23.

    Steps:
        1. Create the output tree under --outdir.
        2. Load the JSON manifest and save a copy to config/manifest_h23.json.
        3. Record a small environment stamp in logs/env.txt.
        4. Run the simulation.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--manifest", required=True)
    parser.add_argument("--outdir", required=True)
    opts = parser.parse_args()
    out = opts.outdir

    ensure_dirs(out, "config", "outputs/metrics", "outputs/audits", "outputs/run_info", "logs")
    with open(opts.manifest, "r", encoding="utf-8") as fh:
        manifest = json.load(fh)

    # Keep a copy of the exact manifest next to the results it produced.
    dump_json(os.path.join(out, "config", "manifest_h23.json"), manifest)

    env_lines = [
        f"utc={utc_timestamp()}",
        f"os={os.name}",
        f"python={sys.version.split()[0]}",
    ]
    write_text(os.path.join(out, "logs", "env.txt"), "\n".join(env_lines) + "\n")

    run_sim(manifest, out)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
